/******************************************************************************
 * Copyright (c) 2011-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are not permitted.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

/**
 * \file
 * Utility for parsing command line arguments
 */

#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

#include <cuda_runtime.h>

namespace cutlass {

/******************************************************************************
 * command_line
 ******************************************************************************/

/**
 * Utility for parsing command line arguments
 */
struct CommandLine {
    /// Keys of every "--<key>" / "--<key>=<value>" argument, in command-line
    /// order (the leading "--" is stripped). Parallel to \c values.
    std::vector<std::string> keys;

    /// Value for each entry of \c keys; empty when the argument had no '='.
    std::vector<std::string> values;

    /// Every argument that does not start with "--" ("naked" arguments),
    /// in command-line order.
    std::vector<std::string> args;

    /**
     * Constructor
     *
     * Splits argv[1..argc) into naked arguments and "--key[=value]" pairs.
     * argv[0] is skipped. The key/value split happens at the first '=' in
     * the argument; everything after it (including further '=' characters)
     * becomes the value.
     */
    CommandLine(int argc, const char** argv) {
        for (int i = 1; i < argc; i++) {
            const std::string arg(argv[i]);

            // Anything not prefixed by "--" is a naked argument.
            if (arg.compare(0, 2, "--") != 0) {
                args.push_back(arg);
                continue;
            }

            const std::string::size_type eq_pos = arg.find('=');
            if (eq_pos == std::string::npos) {
                // Bare flag: "--key"
                keys.push_back(arg.substr(2));
                values.push_back(std::string());
            } else {
                // Key/value pair: "--key=value"
                keys.push_back(arg.substr(2, eq_pos - 2));
                values.push_back(arg.substr(eq_pos + 1));
            }
        }
    }

    /**
     * Checks whether a flag "--<flag>" (with or without a value) is present
     * in the commandline
     */
    bool check_cmd_line_flag(const char* arg_name) const {
        for (size_t i = 0; i < keys.size(); ++i) {
            if (keys[i] == arg_name)
                return true;
        }
        return false;
    }

    /**
     * Returns number of naked (non-flag and non-key-value) commandline
     * parameters
     */
    int num_naked_args() const { return static_cast<int>(args.size()); }

    /**
     * Returns number of naked (non-flag and non-key-value) commandline
     * parameters.
     *
     * The template parameter is unused; this overload is kept so that
     * existing callers spelling num_naked_args<T>() continue to compile.
     */
    template <typename value_t>
    int num_naked_args() const {
        return static_cast<int>(args.size());
    }

    /**
     * Returns the commandline parameter for a given index (not including flags)
     *
     * The naked argument is converted with operator>>. \p val is left
     * untouched when \p index is out of range (including negative).
     */
    template <typename value_t>
    void get_cmd_line_argument(int index, value_t& val) const {
        if (index < 0 || static_cast<size_t>(index) >= args.size()) {
            return;
        }
        std::istringstream str_stream(args[index]);
        str_stream >> val;
    }

    /**
     * Obtains the boolean value specified for a given commandline parameter
     * --<flag>=<bool>
     *
     * When the flag is absent, \p val is set to \p _default. When it is
     * present, \p val is false only if its value is exactly "0" or "false";
     * any other value, including no value at all, yields true.
     */
    void get_cmd_line_argument(const char* arg_name, bool& val,
                               bool _default = true) const {
        val = _default;
        if (check_cmd_line_flag(arg_name)) {
            std::string value;
            get_cmd_line_argument(arg_name, value);

            val = !(value == "0" || value == "false");
        }
    }

    /**
     * Obtains the value specified for a given commandline parameter
     * --<flag>=<value>
     *
     * The incoming content of \p val acts as the default when the flag is
     * absent.
     */
    template <typename value_t>
    void get_cmd_line_argument(const char* arg_name, value_t& val) const {
        get_cmd_line_argument(arg_name, val, val);
    }

    /**
     * Obtains the value specified for a given commandline parameter
     * --<flag>=<value>
     *
     * \p val is first set to \p _default, then every occurrence of the key
     * is extracted into it with operator>> in command-line order, so a later
     * occurrence overrides an earlier one. Because operator>> is used, a
     * std::string result stops at the first whitespace character.
     */
    template <typename value_t>
    void get_cmd_line_argument(const char* arg_name, value_t& val,
                               value_t const& _default) const {
        val = _default;

        for (size_t i = 0; i < keys.size(); ++i) {
            if (keys[i] == arg_name) {
                std::istringstream str_stream(values[i]);
                str_stream >> val;
            }
        }
    }

    /**
     * Returns the values specified for a given commandline parameter
     * --<flag>=<value>,<value>*
     *
     * If the flag is absent \p vals is left untouched (so it may carry
     * defaults). Otherwise \p vals is cleared and filled with the values of
     * every occurrence of the flag, in command-line order.
     */
    template <typename value_t>
    void get_cmd_line_arguments(const char* arg_name,
                                std::vector<value_t>& vals,
                                char sep = ',') const {
        if (!check_cmd_line_flag(arg_name)) {
            return;
        }

        // Clear any default values
        vals.clear();

        // Recover from multi-value string
        for (size_t i = 0; i < keys.size(); ++i) {
            if (keys[i] == arg_name) {
                seperate_string(values[i], vals, sep);
            }
        }
    }

    /**
     * Returns the (first, second) string pairs specified for a given
     * commandline parameter
     * --<flag>=<first:second>,<first:second>*
     *
     * Pairs are appended to \p tokens; nothing is appended when the flag is
     * absent. See tokenize() for the exact splitting rules.
     */
    void get_cmd_line_argument_pairs(
            const char* arg_name,
            std::vector<std::pair<std::string, std::string> >& tokens,
            char delim = ',', char sep = ':') const {
        if (check_cmd_line_flag(arg_name)) {
            std::string value;
            get_cmd_line_argument(arg_name, value);

            tokenize(tokens, value, delim, sep);
        }
    }

    /**
     * Returns a list of ranges specified for a given commandline parameter
     * --<flag>=<value>,<value_start:value_end>*
     *
     * The argument is split on \p delim into ranges, and each range is split
     * on \p sep into its components. One vector per range is appended to
     * \p vals.
     */
    void get_cmd_line_argument_ranges(
            const char* arg_name, std::vector<std::vector<std::string> >& vals,
            char delim = ',', char sep = ':') const {
        std::vector<std::string> ranges;
        get_cmd_line_arguments(arg_name, ranges, delim);

        for (std::vector<std::string>::const_iterator range = ranges.begin();
             range != ranges.end(); ++range) {
            std::vector<std::string> range_vals;
            seperate_string(*range, range_vals, sep);
            vals.push_back(range_vals);
        }
    }

    /**
     * The number of "--key[=value]" pairs parsed
     */
    int parsed_argc() const { return static_cast<int>(keys.size()); }

    //-------------------------------------------------------------------------
    // Utility functions
    //-------------------------------------------------------------------------

    /// Tokenizes a \p delim-delimited list of string pairs whose halves are
    /// separated by \p sep, appending one pair per item to \p tokens. An item
    /// without \p sep yields (item, ""). An empty item between two delimiters
    /// yields ("", ""); a trailing delimiter does not add an extra item.
    static void tokenize(
            std::vector<std::pair<std::string, std::string> >& tokens,
            std::string const& str, char delim = ',', char sep = ':') {
        // Home-built to avoid Boost dependency
        size_t s_idx = 0;
        while (s_idx < str.size()) {
            const size_t d_idx = str.find_first_of(delim, s_idx);
            const size_t end_idx =
                    (d_idx != std::string::npos ? d_idx : str.size());

            // Only a separator located inside the current item counts.
            size_t sep_idx = str.find_first_of(sep, s_idx);
            size_t offset = 1;
            if (sep_idx == std::string::npos || sep_idx >= end_idx) {
                sep_idx = end_idx;
                offset = 0;
            }

            tokens.push_back(std::pair<std::string, std::string>(
                    str.substr(s_idx, sep_idx - s_idx),
                    str.substr(sep_idx + offset, end_idx - sep_idx - offset)));

            s_idx = end_idx + 1;
        }
    }

    /// Tokenizes a \p delim-delimited list of string pairs exactly like the
    /// pair overload, but appends only the first half of each pair to
    /// \p tokens.
    static void tokenize(std::vector<std::string>& tokens,
                         std::string const& str, char delim = ',',
                         char sep = ':') {
        typedef std::vector<std::pair<std::string, std::string> > TokenVector;
        typedef TokenVector::const_iterator token_iterator;

        TokenVector token_pairs;
        tokenize(token_pairs, str, delim, sep);
        for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end();
             ++tok) {
            tokens.push_back(tok->first);
        }
    }

    /// Splits \p str on \p sep and appends each non-empty token, converted
    /// with operator>>, to \p vals.
    ///
    /// Empty tokens (consecutive, leading or trailing separators, or an empty
    /// \p str) are skipped. Each token is parsed through its own stream, so a
    /// token that fails to convert contributes a value-initialized element
    /// (keeping positions aligned) and does not affect the tokens after it.
    template <typename value_t>
    static void seperate_string(std::string const& str,
                                std::vector<value_t>& vals, char sep = ',') {
        std::string::size_type token_begin = 0;

        // Iterate <sep>-delimited values
        while (token_begin < str.size()) {
            std::string::size_type token_end = str.find(sep, token_begin);
            if (token_end == std::string::npos) {
                token_end = str.size();
            }

            if (token_end != token_begin) {
                std::istringstream token_stream(
                        str.substr(token_begin, token_end - token_begin));
                // Value-initialize so a failed extraction never exposes an
                // indeterminate value.
                value_t val = value_t();
                token_stream >> val;
                vals.push_back(val);
            }

            // skip over delimiter
            token_begin = token_end + 1;
        }
    }
};

}  // namespace cutlass
